import torch

# Small demo of torch.argmax on a 2x2 score matrix, followed by a count of
# how many predicted class indices agree with the target labels.
outputs = torch.tensor(
    [
        [0.1, 0.2],
        [0.05, 0.4],
    ]
)

# argmax returns the *index* of the largest element, not the value itself.
# dim=0 reduces down the rows, giving one index per column.
column_winners = torch.argmax(outputs, dim=0)
print(column_winners)

# dim=1 reduces across the columns, giving one index per row; these per-row
# indices are used as the predictions below.
preds = torch.argmax(outputs, dim=1)
print(preds)

target = torch.tensor([0, 1])

# Element-wise comparison of predictions against targets; summing the
# boolean mask counts the matches. Only the second row agrees here, so the
# printed value is 1.
matches = torch.eq(preds, target)
num_correct = matches.sum().item()
print(num_correct)